{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "578ca6e3-9d4c-4291-8884-c1853e13eda6",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2024-06-18T13:41:05.963316Z",
     "iopub.status.busy": "2024-06-18T13:41:05.962683Z",
     "iopub.status.idle": "2024-06-18T13:41:07.106127Z",
     "shell.execute_reply": "2024-06-18T13:41:07.104849Z",
     "shell.execute_reply.started": "2024-06-18T13:41:05.963294Z"
    }
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<torch._C.Generator at 0x7ff9ac026c70>"
      ]
     },
     "execution_count": 1,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "import torch\n",
    "torch.manual_seed(42)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "6075bc3f-5390-4477-b96c-2014b9d48fec",
   "metadata": {},
   "source": [
    "## leaf node ?"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a524b09b-536e-4ad6-bebc-fa83ed8b8e76",
   "metadata": {},
   "source": [
    "- 一个 Tensor/Parameter/Weight 如果是 leaf node\n",
    "    - 其 parts/slice（`W[2,3]`, `W[:2, :3]`） 就不可能再是 leaf node\n",
    "    - 而修改一个 tensor 的 require_grad 值，需要这个 tensor 是 leaf node\n",
    "    - 因此，不能对一个 require_grad 为 True 的 tensor 将其部分置为 False\n",
    "- 那么如何实现对一个 tensor 的部分进行求导呢？"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "4342d24b-431f-452d-b6a7-4583160b2799",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2024-06-18T13:41:07.107091Z",
     "iopub.status.busy": "2024-06-18T13:41:07.106805Z",
     "iopub.status.idle": "2024-06-18T13:41:07.117021Z",
     "shell.execute_reply": "2024-06-18T13:41:07.115731Z",
     "shell.execute_reply.started": "2024-06-18T13:41:07.107072Z"
    }
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[ 1.9269,  1.4873,  0.9007, -2.1055],\n",
       "        [ 0.6784, -1.2345, -0.0431, -1.6047],\n",
       "        [-0.7521,  1.6487, -0.3925, -1.4036],\n",
       "        [-0.7279, -0.5594, -0.7688,  0.7624]], requires_grad=True)"
      ]
     },
     "execution_count": 2,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "W = torch.randn(4, 4, requires_grad=True)\n",
    "W"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "c2ad487a-6b35-4aff-be9c-67f3cac136ca",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2024-06-18T13:41:07.118335Z",
     "iopub.status.busy": "2024-06-18T13:41:07.118130Z",
     "iopub.status.idle": "2024-06-18T13:41:07.133925Z",
     "shell.execute_reply": "2024-06-18T13:41:07.132707Z",
     "shell.execute_reply.started": "2024-06-18T13:41:07.118319Z"
    }
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "True"
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "W.is_leaf"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "93185884-2e34-44fa-bed9-bf9922ec5c74",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2024-06-18T13:41:07.135261Z",
     "iopub.status.busy": "2024-06-18T13:41:07.135059Z",
     "iopub.status.idle": "2024-06-18T13:41:07.144922Z",
     "shell.execute_reply": "2024-06-18T13:41:07.143733Z",
     "shell.execute_reply.started": "2024-06-18T13:41:07.135246Z"
    }
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "False"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "W[2, 3].is_leaf"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "8a9186bf-cab6-4bd9-9450-81f22599a2a5",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2024-06-18T13:41:07.145694Z",
     "iopub.status.busy": "2024-06-18T13:41:07.145499Z",
     "iopub.status.idle": "2024-06-18T13:41:07.157536Z",
     "shell.execute_reply": "2024-06-18T13:41:07.155899Z",
     "shell.execute_reply.started": "2024-06-18T13:41:07.145679Z"
    }
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "False"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "W[:2, :3].is_leaf"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "48ac054e-940b-4b94-acb4-280a413e1e1f",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2024-06-18T13:41:07.160138Z",
     "iopub.status.busy": "2024-06-18T13:41:07.159897Z",
     "iopub.status.idle": "2024-06-18T13:41:07.637975Z",
     "shell.execute_reply": "2024-06-18T13:41:07.636006Z",
     "shell.execute_reply.started": "2024-06-18T13:41:07.160121Z"
    }
   },
   "outputs": [
    {
     "ename": "RuntimeError",
     "evalue": "you can only change requires_grad flags of leaf variables. If you want to use a computed variable in a subgraph that doesn't require differentiation use var_no_grad = var.detach().",
     "output_type": "error",
     "traceback": [
      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[0;31mRuntimeError\u001b[0m                              Traceback (most recent call last)",
      "Cell \u001b[0;32mIn[6], line 1\u001b[0m\n\u001b[0;32m----> 1\u001b[0m \u001b[43mW\u001b[49m\u001b[43m[\u001b[49m\u001b[43m:\u001b[49m\u001b[38;5;241;43m2\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43m:\u001b[49m\u001b[38;5;241;43m3\u001b[39;49m\u001b[43m]\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mrequires_grad_\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;28;43;01mFalse\u001b[39;49;00m\u001b[43m)\u001b[49m\n",
      "\u001b[0;31mRuntimeError\u001b[0m: you can only change requires_grad flags of leaf variables. If you want to use a computed variable in a subgraph that doesn't require differentiation use var_no_grad = var.detach()."
     ]
    }
   ],
   "source": [
    "W[:2, :3].requires_grad_(False)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "df452adc-5305-4f11-ba59-84c5f4dfd7ae",
   "metadata": {},
   "source": [
    "## Mask"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "51102b57-80d2-4eb2-8673-c31d0d8ff876",
   "metadata": {},
   "source": [
    "- 那么如何实现对一个 tensor 的部分进行求导呢？\n",
    "    - Hadamard product（按位乘）一个 Mask 矩阵\n",
    "    - `W*M`: 是按位乘，而不是矩阵乘法（`W@M`，才是矩阵乘法）\n",
    "- 后边我们可以验证的是，前向过程中对 Weight 进行 Mask，与先 backward，再对 grad 进行 Mask 效果是一致的"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "401d9854-b59a-4dbc-be08-e3bfa898ca9e",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2024-06-18T13:41:15.240921Z",
     "iopub.status.busy": "2024-06-18T13:41:15.240611Z",
     "iopub.status.idle": "2024-06-18T13:41:15.252369Z",
     "shell.execute_reply": "2024-06-18T13:41:15.250256Z",
     "shell.execute_reply.started": "2024-06-18T13:41:15.240901Z"
    }
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[ 1.9269,  1.4873,  0.9007, -2.1055],\n",
       "        [ 0.6784, -1.2345, -0.0431, -1.6047],\n",
       "        [-0.7521,  1.6487, -0.3925, -1.4036],\n",
       "        [-0.7279, -0.5594, -0.7688,  0.7624]], requires_grad=True)"
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "W"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "507235b6-d1e8-4873-84ae-ecdf41bcebd5",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2024-06-18T13:41:17.014732Z",
     "iopub.status.busy": "2024-06-18T13:41:17.014168Z",
     "iopub.status.idle": "2024-06-18T13:41:17.028212Z",
     "shell.execute_reply": "2024-06-18T13:41:17.026124Z",
     "shell.execute_reply.started": "2024-06-18T13:41:17.014689Z"
    }
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[0., 0., 1., 0.],\n",
       "        [1., 1., 1., 0.],\n",
       "        [1., 1., 1., 1.],\n",
       "        [0., 1., 0., 1.]])"
      ]
     },
     "execution_count": 8,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "M = torch.bernoulli(torch.full(W.shape, 0.5))\n",
    "M"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "0e59400a-bc29-4b11-aa7a-76bac118eddd",
   "metadata": {},
   "source": [
    "### 前向对 Weight 进行 Mask"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "8c2fd754-715f-4686-87cf-24a103edbc53",
   "metadata": {},
   "source": [
    "$$\n",
    "\\begin{split}\n",
    "&y=Wx, L=f(y)\\\\\n",
    "&\\frac{\\partial L}{\\partial W}=\\frac{\\partial L}{\\partial y}\\frac{\\partial y}{\\partial W}=\\frac{\\partial L}{\\partial y}x^T\\\\\n",
    "\\end{split}\n",
    "$$"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "f0b22e7d-59ac-4a55-a70a-e52651ec5a19",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2024-06-18T13:41:18.942560Z",
     "iopub.status.busy": "2024-06-18T13:41:18.941979Z",
     "iopub.status.idle": "2024-06-18T13:41:18.955143Z",
     "shell.execute_reply": "2024-06-18T13:41:18.953090Z",
     "shell.execute_reply.started": "2024-06-18T13:41:18.942516Z"
    }
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([ 1.3221,  0.8172, -0.7658, -0.7506])"
      ]
     },
     "execution_count": 9,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "x = torch.randn(4)\n",
    "x"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "d4101f7d-ea51-4709-8d1f-2eb1c13e7706",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2024-06-18T13:41:35.167935Z",
     "iopub.status.busy": "2024-06-18T13:41:35.167279Z",
     "iopub.status.idle": "2024-06-18T13:41:35.182369Z",
     "shell.execute_reply": "2024-06-18T13:41:35.180464Z",
     "shell.execute_reply.started": "2024-06-18T13:41:35.167888Z"
    }
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[ 0.0000,  0.0000,  0.9007, -0.0000],\n",
       "        [ 0.6784, -1.2345, -0.0431, -0.0000],\n",
       "        [-0.7521,  1.6487, -0.3925, -1.4036],\n",
       "        [-0.0000, -0.5594, -0.0000,  0.7624]], grad_fn=<MulBackward0>)"
      ]
     },
     "execution_count": 10,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "W_m = W*M\n",
    "W_m"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "5bf9bc0f-ff5b-4cad-a790-e3fb56053781",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2024-06-18T13:42:22.508226Z",
     "iopub.status.busy": "2024-06-18T13:42:22.507553Z",
     "iopub.status.idle": "2024-06-18T13:42:22.588753Z",
     "shell.execute_reply": "2024-06-18T13:42:22.587477Z",
     "shell.execute_reply.started": "2024-06-18T13:42:22.508178Z"
    }
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[ 0.0000,  0.0000, -0.7658, -0.0000],\n",
       "        [ 1.3221,  0.8172, -0.7658, -0.0000],\n",
       "        [ 1.3221,  0.8172, -0.7658, -0.7506],\n",
       "        [ 0.0000,  0.8172, -0.0000, -0.7506]])"
      ]
     },
     "execution_count": 11,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "y = W_m @ x\n",
    "y.sum().backward()\n",
    "W.grad"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "6f03e821-1660-4948-aec2-01694f8e08d9",
   "metadata": {},
   "source": [
    "### 先 backward，对 Grad 进行 Mask "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "8a236773-1e95-4592-af8e-dc19a74be4f3",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2024-06-18T13:42:39.384079Z",
     "iopub.status.busy": "2024-06-18T13:42:39.383617Z",
     "iopub.status.idle": "2024-06-18T13:42:39.402158Z",
     "shell.execute_reply": "2024-06-18T13:42:39.400122Z",
     "shell.execute_reply.started": "2024-06-18T13:42:39.384048Z"
    },
    "scrolled": true
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[0., 0., 0., 0.],\n",
       "        [0., 0., 0., 0.],\n",
       "        [0., 0., 0., 0.],\n",
       "        [0., 0., 0., 0.]])"
      ]
     },
     "execution_count": 12,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "W.grad.zero_()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "id": "2f652197-772e-41e1-84cc-8bf684b4a09a",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2024-06-18T13:42:57.592565Z",
     "iopub.status.busy": "2024-06-18T13:42:57.591950Z",
     "iopub.status.idle": "2024-06-18T13:42:57.611241Z",
     "shell.execute_reply": "2024-06-18T13:42:57.609242Z",
     "shell.execute_reply.started": "2024-06-18T13:42:57.592520Z"
    }
   },
   "outputs": [],
   "source": [
    "y = W @ x\n",
    "y.sum().backward()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "id": "6d2b021a-a188-4867-a610-9e49a74bb8b2",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2024-06-18T13:43:14.779863Z",
     "iopub.status.busy": "2024-06-18T13:43:14.779191Z",
     "iopub.status.idle": "2024-06-18T13:43:14.794097Z",
     "shell.execute_reply": "2024-06-18T13:43:14.792021Z",
     "shell.execute_reply.started": "2024-06-18T13:43:14.779815Z"
    }
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[ 0.0000,  0.0000, -0.7658, -0.0000],\n",
       "        [ 1.3221,  0.8172, -0.7658, -0.0000],\n",
       "        [ 1.3221,  0.8172, -0.7658, -0.7506],\n",
       "        [ 0.0000,  0.8172, -0.0000, -0.7506]])"
      ]
     },
     "execution_count": 15,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "W.grad * M"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4d0d9c58-535d-4d62-9c56-a44ffe914858",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
